{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataVec \n",
    "\n",
    "The [DataVec](https://deeplearning4j.org/datavec) library from DL4J is easy to add to the BeakerX kernel, including displaying its tables with the BeakerX interactive table widget.  DataVec is an ETL Library for Machine Learning, including data pipelines, data munging, and wrangling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%classpath add mvn\n",
    "org.datavec datavec-api 0.9.1\n",
    "org.datavec datavec-local 0.9.1\n",
    "org.datavec datavec-dataframe 0.9.1\n",
    "org.deeplearning4j deeplearning4j-core 0.9.1\n",
    "org.nd4j nd4j-native-platform 0.9.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%import org.nd4j.linalg.api.ndarray.INDArray\n",
    "%import org.datavec.api.split.FileSplit\n",
    "%import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator\n",
    "%import java.nio.file.Paths\n",
    "%import org.nd4j.linalg.factory.Nd4j\n",
    "%import org.datavec.api.transform.TransformProcess\n",
    "%import org.datavec.api.records.reader.impl.csv.CSVRecordReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import org.datavec.api.transform.schema.Schema\n",
    "\n",
    "inputDataSchema = new Schema.Builder()\n",
    "            //We can for convenience define multiple columns of the same type\n",
    "            .addColumnsString(\"DateString\", \"TimeString\")\n",
    "            //We can define different column types for different types of data:\n",
    "            .addColumnCategorical(\"State\", Arrays.asList(\"GA\",\"VA\",\"IL\",\"MO\",\"IN\",\"KY\",\"MS\",\"LA\",\"AL\",\"TN\",\"OH\",\"NC\",\"MD\",\"CA\",\"AZ\",\"FL\",\"IA\",\"MN\",\"KS\",\"TX\",\"OK\",\"AR\",\"NE\",\"WA\",\"WY\",\"CO\",\"ID\",\"SD\",\"PA\",\"MT\",\"NV\",\"NY\",\"DE\",\"NM\",\"ME\",\"ND\",\"SC\",\"WV\",\"MI\",\"WI\",\"NH\",\"CT\",\"MA\"))\n",
    "            .addColumnsInteger(\"State No\", \"Scale\", \"Injuries\", \"Fatalities\")\n",
    "            //Floating-point measurements are declared as double columns; no value restrictions are set here:\n",
    "            .addColumnsDouble(\"Start Lat\", \"Start Lon\", \"Length\", \"Width\")\n",
    "            .build();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.datavec.api.transform.condition.ConditionOp\n",
    "import org.datavec.api.transform.condition.column.CategoricalColumnCondition\n",
    "import org.datavec.api.transform.filter.ConditionFilter\n",
    "\n",
    "transformProcess = new TransformProcess.Builder(inputDataSchema)\n",
    "  //Let's remove some columns we don't need\n",
    "  .removeColumns(\"DateString\", \"TimeString\", \"State No\")\n",
    "  //Now, suppose we only want to analyze tornadoes involving NY and WA. Let's filter out\n",
    "  // everything except for those states.\n",
    "  //Here, we are applying a conditional filter. We remove all of the examples that match the condition\n",
    "  // The condition is \"State\" isn't one of {\"NY\", \"WA\"}\n",
    "  .filter(new ConditionFilter(\n",
    "                new CategoricalColumnCondition(\"State\", ConditionOp.NotInSet, new HashSet<>(Arrays.asList(\"NY\", \"WA\")))))\n",
    "  .build();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import org.datavec.local.transforms.TableRecords\n",
    "import jupyter.Displayer;\n",
    "import jupyter.Displayers;\n",
    "\n",
    "// Register a JVM repr displayer so that DataVec tables are shown with the BeakerX\n",
    "// table widget instead of as a raw string table\n",
    "Displayers.register(org.datavec.dataframe.api.Table.class, new Displayer<org.datavec.dataframe.api.Table>() {\n",
    "      @Override\n",
    "      public Map<String, String> display(org.datavec.dataframe.api.Table table) {\n",
    "        // Copy the table cell by cell into row-major lists for the widget\n",
    "        List<List<String>> values = new ArrayList<>();\n",
    "        for (int row = 0; row < table.rowCount(); row++) {\n",
    "          List<String> rowValues = new ArrayList<>();\n",
    "          for (int column = 0; column < table.columnCount(); column++) {\n",
    "            rowValues.add(table.get(column, row));\n",
    "          }\n",
    "          values.add(rowValues);\n",
    "        }\n",
    "        // The widget is shown as a side effect of display(); an empty list is passed as the third argument\n",
    "        new TableDisplay(values, table.columnNames(), new ArrayList()).display();\n",
    "        // Only the HIDDEN MIME entry is returned (presumably so that nothing else is rendered for the value)\n",
    "        Map<String, String> mimeBundle = new HashMap<>();\n",
    "        mimeBundle.put(MIMEContainer.MIME.HIDDEN, \"\");\n",
    "        return mimeBundle;\n",
    "      }\n",
    "    });\n",
    "\n",
    "// Empty table whose columns follow the schema produced by the transform process\n",
    "outputSchema = transformProcess.getFinalSchema()\n",
    "table = TableRecords.tableFromSchema(outputSchema)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### For the purposes of our example, we load data from a CSV file and transform it using DataVec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader\n",
    "\n",
    "// Records that survive the transform process; read by the table-filling cell below\n",
    "writable = []\n",
    "\n",
    "TransformProcessRecordReader tprr = new TransformProcessRecordReader(new CSVRecordReader(0, \",\"), transformProcess)\n",
    "try {\n",
    "    tprr.initialize(new FileSplit(Paths.get(\"../resources/data/tornadoes_2014.csv\").toFile()))\n",
    "\n",
    "    // Extract filtered data: results that are null (or empty) fail the Groovy truth test and are skipped\n",
    "    while (tprr.hasNext()) {\n",
    "        def elem = tprr.next()\n",
    "        if (elem) {\n",
    "            writable.add(elem)\n",
    "        }\n",
    "    }\n",
    "} finally {\n",
    "    // Release the underlying file handle even if reading fails\n",
    "    tprr.close()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Fill Table with extracted data.\n",
    "// NOTE: cells are added to the existing columns, so re-running this cell without first\n",
    "// re-creating the table presumably duplicates the rows - verify before re-running.\n",
    "// size() is called as a method: writable.size would be a property lookup on the list, not the element count.\n",
    "for (int row = 0; row < writable.size(); row++) {\n",
    "    for (int col = 0; col < outputSchema.numColumns(); col++) {\n",
    "        // Every value is handed to addCell as a String\n",
    "        table.column(col).addCell(\"\" + writable[row][col])\n",
    "    }\n",
    "}\n",
    "\n",
    "table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of a network that classifies two groups of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.datavec.api.records.reader.RecordReader\n",
    "import org.datavec.api.util.ClassPathResource\n",
    "import org.deeplearning4j.eval.Evaluation\n",
    "import org.deeplearning4j.nn.api.OptimizationAlgorithm\n",
    "import org.deeplearning4j.nn.conf.MultiLayerConfiguration\n",
    "import org.deeplearning4j.nn.conf.NeuralNetConfiguration\n",
    "import org.deeplearning4j.nn.conf.Updater\n",
    "import org.deeplearning4j.nn.conf.layers.DenseLayer\n",
    "import org.deeplearning4j.nn.conf.layers.OutputLayer\n",
    "import org.deeplearning4j.nn.multilayer.MultiLayerNetwork\n",
    "import org.deeplearning4j.nn.weights.WeightInit\n",
    "import org.deeplearning4j.optimize.listeners.ScoreIterationListener\n",
    "import org.nd4j.linalg.activations.Activation\n",
    "import org.nd4j.linalg.dataset.DataSet\n",
    "import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction\n",
    "import org.nd4j.linalg.dataset.api.iterator.DataSetIterator\n",
    "\n",
    "// Hyperparameters (local to this cell)\n",
    "def seed = 123;\n",
    "def learningRate = 0.01;\n",
    "def batchSize = 50;\n",
    "def nEpochs = 30;\n",
    "\n",
    "def numInputs = 2;\n",
    "def numOutputs = 2;\n",
    "def numHiddenNodes = 20;\n",
    "\n",
    "// Index of the label column in both CSV files\n",
    "def labelIndex = 0;\n",
    "\n",
    "def filenameTrain = Paths.get(\"../resources/data/linear_data_train.csv\").toFile();\n",
    "def filenameTest = Paths.get(\"../resources/data/linear_data_eval.csv\").toFile();\n",
    "\n",
    "//Load the training data (rr is a notebook-level variable: a later cell re-initializes it):\n",
    "rr = new CSVRecordReader();\n",
    "rr.initialize(new FileSplit(filenameTrain));\n",
    "trainIter = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numOutputs);\n",
    "\n",
    "//Load the test/evaluation data (rrTest is likewise reused by a later cell):\n",
    "rrTest = new CSVRecordReader();\n",
    "rrTest.initialize(new FileSplit(filenameTest));\n",
    "testIter = new RecordReaderDataSetIterator(rrTest, batchSize, labelIndex, numOutputs);\n",
    "\n",
    "conf = new NeuralNetConfiguration.Builder()\n",
    "  .seed(seed)\n",
    "  .iterations(1)\n",
    "  .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)\n",
    "  .learningRate(learningRate)\n",
    "  .updater(Updater.NESTEROVS)     //To configure: .updater(new Nesterovs(0.9))\n",
    "  .list()\n",
    "  .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)\n",
    "    .weightInit(WeightInit.XAVIER)\n",
    "    .activation(Activation.RELU)\n",
    "    .build())\n",
    "  .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)\n",
    "    .weightInit(WeightInit.XAVIER)\n",
    "    .activation(Activation.SOFTMAX)\n",
    "    .nIn(numHiddenNodes).nOut(numOutputs).build())\n",
    "  .pretrain(false)\n",
    "  .backprop(true)\n",
    "  .build();\n",
    "\n",
    "model = new MultiLayerNetwork(conf);\n",
    "model.init();\n",
    "model.setListeners(new ScoreIterationListener(10));  //Print score every 10 parameter updates\n",
    "\n",
    "// One call to fit() per epoch\n",
    "for (int n = 0; n < nEpochs; n++) {\n",
    "  model.fit(trainIter);\n",
    "}\n",
    "\n",
    "print \"Evaluate model....\"\n",
    "eval = new Evaluation(numOutputs);\n",
    "\n",
    "// Accumulate evaluation statistics batch by batch over the test set\n",
    "while (testIter.hasNext()) {\n",
    "  def batch = testIter.next();\n",
    "  INDArray features = batch.getFeatureMatrix();\n",
    "  INDArray labels = batch.getLabels();\n",
    "  INDArray predicted = model.output(features, false);\n",
    "\n",
    "  eval.eval(labels, predicted);\n",
    "}\n",
    "\n",
    "//Print the evaluation statistics\n",
    "print eval.stats()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.nd4j.linalg.api.ops.impl.indexaccum.IMax\n",
    "\n",
    "def filenameTrain = Paths.get(\"../resources/data/linear_data_train.csv\").toFile();\n",
    "def filenameTest = Paths.get(\"../resources/data/linear_data_eval.csv\").toFile();\n",
    "\n",
    "/**\n",
    " * Builds a map from a point to its class index.\n",
    " *\n",
    " * The key of each entry is the list [feature 0, feature 1] of one row of features, so rows\n",
    " * with identical coordinates share a single entry. The value is the index of the maximum\n",
    " * along dimension 1 of the matching row of labels.\n",
    " *\n",
    " * @param features array whose first two columns are read for every row\n",
    " * @param labels   array of per-class values (labels or network output), one row per feature row\n",
    " * @return map of [x, y] -> class index\n",
    " */\n",
    "def extractDataFromND(features, labels) {\n",
    "    def classesOfPoints = [:]\n",
    "    INDArray argMax = Nd4j.getExecutioner().exec(new IMax(labels), 1);\n",
    "    int rowCount = features.rows()\n",
    "\n",
    "    for (int i = 0; i < rowCount; i++) {\n",
    "        int classIdx = (int) argMax.getDouble(i);\n",
    "        classesOfPoints.put([features.getDouble(i, 0), features.getDouble(i, 1)], classIdx)\n",
    "    }\n",
    "\n",
    "    return classesOfPoints\n",
    "}\n",
    "\n",
    "// Training points with their true classes, read as one batch of up to 1000 examples\n",
    "rr.initialize(new FileSplit(filenameTrain))\n",
    "rr.reset()\n",
    "trainIter = new RecordReaderDataSetIterator(rr, 1000, 0, 2)\n",
    "ds = trainIter.next();\n",
    "rawTrainData = extractDataFromND(ds.getFeatures(), ds.getLabels())\n",
    "\n",
    "// Test points with the classes predicted by the model, read as one batch of up to 500 examples\n",
    "rrTest.initialize(new FileSplit(filenameTest))\n",
    "rrTest.reset();\n",
    "testIter = new RecordReaderDataSetIterator(rrTest,500,0,2);\n",
    "ds = testIter.next();\n",
    "INDArray testPredicted = model.output(ds.getFeatures());\n",
    "\n",
    "rawPredictedData = extractDataFromND(ds.getFeatures(), testPredicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot = new Plot(title: \"Training data Plot\")\n",
    "plot.setXBound([0.0, 1.0])\n",
    "plot.setYBound([-0.2, 0.8])\n",
    "\n",
    "// One Points series per sample: class 0 is drawn in orange, every other class in red\n",
    "rawTrainData.each { point, classIdx ->\n",
    "    def pointColor = (classIdx == 0) ? Color.orange : Color.red\n",
    "    plot << new Points(x: [point[0]], y: [point[1]], color: pointColor)\n",
    "}\n",
    "\n",
    "plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot = new Plot(title: \"Predicted data Plot\")\n",
    "plot.setXBound([0.0, 1.0])\n",
    "plot.setYBound([-0.2, 0.8])\n",
    "\n",
    "// One Points series per sample: class 0 is drawn in orange, every other class in red\n",
    "rawPredictedData.each { point, classIdx ->\n",
    "    def pointColor = (classIdx == 0) ? Color.orange : Color.red\n",
    "    plot << new Points(x: [point[0]], y: [point[1]], color: pointColor)\n",
    "}\n",
    "\n",
    "plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Groovy",
   "language": "groovy",
   "name": "groovy"
  },
  "language_info": {
   "codemirror_mode": "groovy",
   "file_extension": ".groovy",
   "mimetype": "",
   "name": "Groovy",
   "nbconverter_exporter": "",
   "version": "2.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
